####################################################
#Author: Kelli Marquardt
#Purpose: Simulate Fake patient note dataset for replication package, Mis(sed) Diagnosis: Physician Decision Making and ADHD 

# Inputs:

# Outputs:
#- data/note_dat_fake_base.csv
####################################################


############################
#0 load required packages
############################
# Remove every object (hidden ones included) from the current workspace
# before the script builds anything.
# NOTE(review): this also wipes whatever the caller had in the session;
# kept because the script has always done it.
rm(list = ls(all.names = TRUE))

required.packages <- c("dplyr", "MASS", "tidyr")

# system.file(package = p) returns "" when p is not installed. This replaces
# a scan of installed.packages(), which R's documentation advises against
# for checking named packages because it reads the whole library.
is_installed <- vapply(
  required.packages,
  function(pkg) nzchar(system.file(package = pkg)),
  logical(1)
)
missing_pkgs <- required.packages[!is_installed]

if (length(missing_pkgs) > 0) {
  message(
    "Installing missing packages: ",
    paste(missing_pkgs, collapse = ", ")
  )
  install.packages(missing_pkgs, dependencies = TRUE)
} else {
  message("All required packages are already installed.")
}

# Drop the bookkeeping objects so they do not linger in the workspace.
rm(is_installed, missing_pkgs, required.packages)

#load packages
library(MASS)
library(dplyr) 
library(tidyr)


####################################
# set seed for replication
# Seed value for the random number generator; every random draw made after
# this point is reproducible from it.
seed_for_fake_data <- 12345
set.seed(seed = seed_for_fake_data)
####################################


######################################
#0: define helper functions  
######################################

# Restrict x element-wise to the interval [lo, hi]:
# values below lo become lo, values above hi become hi.
clamp <- function(x, lo, hi) {
  raised <- pmax(x, lo)
  pmin(raised, hi)
}

# Draw n values from Uniform(lo, hi), optionally re-centred on target_mean.
#   n            number of draws
#   lo, hi       bounds of the draw and of the returned values
#   target_mean  desired centre; NA (the default) returns the raw draws
# With a target, every draw is shifted by (target_mean - midpoint of [lo, hi])
# and then clamped back into [lo, hi]. Draws pushed past a bound are set to
# that bound, so the result is not exactly uniform and its mean only
# approximates target_mean.
runif_shifted_mean <- function(n, lo, hi, target_mean = NA_real_) {
  draws <- runif(n, lo, hi)
  if (!is.na(target_mean)) {
    shift <- target_mean - (lo + hi) / 2
    draws <- clamp(draws + shift, lo, hi)
  }
  draws
}


############################################
# Step 1: build dataset basics (pat_id, visit_id, assigned_label)
############################################
#100 patients with 3 appointments (visits) each
#assigned_label takes the values 1, 0, or NA
#30 patients with all visits assigned_label=0
#10 patients with all visits assigned_label=1
#10 patients with all visits assigned_label=NA
#25 patients with both assigned_label=1 and assigned_label=NA
#25 patients with both assigned_label=0 and assigned_label=NA


# Base panel: 100 patients x 3 visits, one row per (pat_id, visit_id).
note_dat <- expand_grid(
  pat_id   = 1:100,
  visit_id = 1:3
)

# Assign the visit-level label from fixed pat_id ranges.
# Missing labels are written as NA_real_ rather than bare (logical) NA so that
# every branch is a double; dplyr versions that type-check case_when() and
# if_else() strictly reject a logical NA mixed with numeric values.
note_dat <- note_dat %>%
  mutate(
    assigned_label = case_when(
      # 30 patients: all visits 0
      pat_id %in% 1:30 ~ 0,

      # 10 patients: all visits 1
      pat_id %in% 31:40 ~ 1,

      # 10 patients: all visits NA
      pat_id %in% 41:50 ~ NA_real_,

      # 25 patients: NA for visit 1, 1 for visits 2 and 3
      pat_id %in% 51:75 ~ if_else(visit_id == 1, NA_real_, 1),

      # 25 patients: NA for visit 1, 0 for visits 2 and 3
      pat_id %in% 76:100 ~ if_else(visit_id == 1, NA_real_, 0)
    )
  )


############################################
# Step 2: add behav.val which is used for validation (assign randomly before notes)
############################################
  #=1 for 80% of those with assigned_label=1 and 20% of those with assigned_label=NA
  #=0 otherwise 
# behav.val by label:
#   assigned_label missing -> Bernoulli(0.2) draw
#   assigned_label == 0    -> 0
#   otherwise (label 1)    -> Bernoulli(0.8) draw
# Each rbinom() call draws one value per row of the table; case_when() keeps
# only the draws on the rows whose condition matches. The draws are converted
# with as.numeric() so every branch is a double (rbinom() returns integers,
# the literal 0 is a double), which strictly type-checked case_when()
# implementations require.
note_dat <- note_dat %>%
  mutate(
    behav.val = case_when(
      is.na(assigned_label) ~ as.numeric(rbinom(n(), size = 1, prob = .2)),
      assigned_label == 0 ~ 0,
      TRUE ~ as.numeric(rbinom(n(), size = 1, prob = .8))
    )
  )


############################################
# Step 3: add adhd_dx to random set of visits with assigned_label=1 
############################################
#=1 for 25% of visits with assigned_label=1 
#=0 otherwise 
# adhd_dx by label:
#   assigned_label missing or 0 -> 0
#   otherwise (label 1)         -> Bernoulli(0.25) draw
# rbinom() draws one value per row of the table; case_when() keeps only the
# draws on the rows that reach the final branch. as.numeric() makes the
# integer draws match the double 0 used by the other branches, which strictly
# type-checked case_when() implementations require.
note_dat <- note_dat %>%
  mutate(
    adhd_dx = case_when(
      is.na(assigned_label) ~ 0,
      assigned_label == 0 ~ 0,
      TRUE ~ as.numeric(rbinom(n(), size = 1, prob = .25))
    )
  )

############################################
# Step 4: add age 
############################################
#for each patient, get age at visit 2 (based on mean and variance in table 2)
#then subtract max(1 day, N(60 days, 30 days)) for visit 1 = random
#then add max(1 day, N(60 days, 30 days)) for visit 3 = random


# Age (in years) at each visit: visit 2 carries the patient's baseline age,
# visit 1 lies one random gap before it, and any other visit one random gap
# after it. The helper columns are dropped once age is built.
note_dat <- note_dat %>%
  group_by(pat_id) %>%
  mutate(
    # single draw per patient, recycled across that patient's rows
    visit2_age = runif_shifted_mean(1, 5, 18, 10.3)[1],

    # one draw per row: gap in years, mean 60 days and sd 30 days
    gap_raw = rnorm(n(), mean = 60 / 365, sd = 30 / 365)
  ) %>%
  ungroup() %>%
  mutate(
    # floor the gap at one day so it is never zero or negative
    gap = pmax(1 / 365, gap_raw),
    age = case_when(
      visit_id == 1 ~ visit2_age - gap,
      visit_id == 2 ~ visit2_age,
      TRUE ~ visit2_age + gap
    )
  ) %>%
  select(-c(visit2_age, gap_raw, gap))



############################################
# Step 5: add male 
############################################
#for each patient, random male with p=.5

# male: one Bernoulli(0.5) draw per patient; the single value is recycled to
# all of that patient's visits.
note_dat <- note_dat %>%
  group_by(pat_id) %>%
  mutate(
    male = rbinom(1, size = 1, prob = 0.5)
  ) %>%
  ungroup()

############################################
# Step 6: save
############################################

# Write the simulated table to disk without a row-name column. The path is
# relative to the working directory, so data/ must already exist there.
# FALSE is spelled out because the shorthand F is an ordinary variable that
# can be reassigned.
write.csv(note_dat, "data/note_dat_fake_base.csv", row.names = FALSE)


#END OF SCRIPT

